{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Seq2Seq(with Attention)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch: 0400 cost = 0.000012\n",
      "Epoch: 0800 cost = 0.067272\n",
      "Epoch: 1200 cost = 0.000000\n",
      "Epoch: 1600 cost = 0.000000\n",
      "Epoch: 2000 cost = 0.000000\n",
      "['ich', 'mochte', 'ein', 'bier', 'P'] -> ['i', 'want', 'a', 'beer', 'E']\n"
     ]
    }
   ],
   "source": [
    "tf.reset_default_graph()\n",
    "#创建真正的翻译seq2seq\n",
    "\n",
    "\n",
    "#P是用来填充长度不够的字符串的\n",
    "#S表示decoder的输入开端\n",
    "#E表示decoder的输出结尾\n",
    "sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']#data\n",
    "words = ' '.join(sentences).split()#使用’ ‘连接list中的每个元素 构建词表\n",
    "words = list(set(words))#去一下重\n",
    "word_num = {w:i for i,w in enumerate(words)}#建立数字和单词的对应关系\n",
    "num_word = {i:w for i,w in enumerate(words)}\n",
    "\n",
    "#定义超参数\n",
    "n_step = 5#句子长度\n",
    "n_hidden = 128#隐层单元数量\n",
    "n_class = len(word_num)#词表长度\n",
    "\n",
    "def get_batch(sentences):\n",
    "    input_batch = [np.eye(n_class)[[word_num[w] for w in sentences[0].split()]]]#[batch_size,n_step,n_class]\n",
    "    output_batch = [np.eye(n_class)[[word_num[w] for w in sentences[1].split()]]]#[batch_size,n_step+1,n_class]\n",
    "    target_batch = [[word_num[w] for w in sentences[2].split()]]#[batch_size,n_step+1]\n",
    "    return input_batch,output_batch,target_batch\n",
    "\n",
    "enc_input = tf.placeholder(tf.float32,[None,None,n_class])#[batch_size,n_step,n_class]\n",
    "dec_input = tf.placeholder(tf.float32,[None,None,n_class])#[batch_size,n_step,n_class]\n",
    "target = tf.placeholder(tf.int64,[1, n_step])##[batch_size,n_step]不是one-hot\n",
    "\n",
    "attn = tf.Variable(tf.random_normal([n_hidden,n_hidden]))\n",
    "out = tf.Variable(tf.random_normal([n_hidden*2,n_class]))\n",
    "\n",
    "def get_att_score(decoder_hidden,encoder_hidden):\n",
    "    #encoder_hidden[batch_size,n_hidden]\n",
    "    score = tf.squeeze(tf.matmul(encoder_hidden,attn),0)#[n_hidden]\n",
    "    decoder_hidden = tf.squeeze(decoder_hidden,[0,1])#[n_hidden]\n",
    "    return tf.tensordot(decoder_hidden,score,1)\n",
    "    \n",
    "def get_att_weight(decoder_hidden,encoder_hidden):\n",
    "    attn_score = []\n",
    "    encoder_hidden = tf.transpose(encoder_hidden,[1,0,2])#[n_step,batch_size,n_hidden]\n",
    "    for i in range(n_step):\n",
    "        attn_score.append(get_att_score(decoder_hidden,encoder_hidden[i]))\n",
    "    return tf.reshape(tf.nn.softmax(attn_score),[1,1,-1])# [1, 1, n_step]\n",
    "\n",
    "Attention = []#[n_step,n_step]\n",
    "model = []\n",
    "\n",
    "with tf.variable_scope('encoder'):\n",
    "    encoder = tf.contrib.rnn.BasicRNNCell(n_hidden)\n",
    "    encoder = tf.contrib.rnn.DropoutWrapper(encoder,output_keep_prob=0.5)\n",
    "    encoder_hidden,encoder_output = tf.nn.dynamic_rnn(encoder,enc_input,dtype=tf.float32)\n",
    "    #enc_input [batch_size,n_step,n_class] \n",
    "    #encoder_hidden [batch_size,n_step,n_hidden]\n",
    "    #encoder_output [batch_size,n_hidden]\n",
    "    \n",
    "with tf.variable_scope('decoder'):\n",
    "    decoder = tf.contrib.rnn.BasicRNNCell(n_hidden)\n",
    "    decoder = tf.contrib.rnn.DropoutWrapper(decoder,output_keep_prob=0.5)#dropout 随机失活\n",
    "    inputs = tf.transpose(dec_input,[1,0,2])#Time_major 如果为True 那么表示输入的向量的第一维度必须是time_step\n",
    "    hidden = encoder_output\n",
    "    for i in range(n_step):\n",
    "        #这里的输入只取dec_input的第一项\n",
    "        decoder_hidden,hidden = tf.nn.dynamic_rnn(decoder,tf.expand_dims(inputs[i],1),initial_state=hidden,\n",
    "                                                          dtype=tf.float32,time_major = True)#tf.expand_dims 在指定的地方增加一个维度\n",
    "        att_weights = get_att_weight(decoder_hidden,encoder_hidden)\n",
    "        Attention.append(tf.squeeze(att_weights))\n",
    "        context = tf.matmul(att_weights,encoder_hidden)#[1,1,n_step]x[1,n_step,n_hidden] = [1,1,n_hidden]\n",
    "        decoder_hidden = tf.squeeze(decoder_hidden,0)#[1,n_step]\n",
    "        context = tf.squeeze(context, 1)#[1,n_hidden]\n",
    "        model.append(tf.matmul(tf.concat((decoder_hidden,context),1),out))# [n_step, batch_size(=1), n_class]\n",
    "\n",
    "train_att = tf.stack([Attention[0],Attention[1],Attention[2],Attention[3],Attention[4]],0)\n",
    "model = tf.transpose(model,[1,0,2])# [batch_size(=1),n_step, n_class]\n",
    "prediction = tf.argmax(model,2)\n",
    "cost = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=model, labels=target))\n",
    "optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    init = tf.global_variables_initializer()\n",
    "    sess.run(init)\n",
    "    for epoch in range(2000):\n",
    "        input_batch, output_batch, target_batch = get_batch(sentences)\n",
    "        _, loss, attention = sess.run([optimizer, cost, train_att],\n",
    "                                      feed_dict={enc_input: input_batch, dec_input: output_batch, target: target_batch})\n",
    "\n",
    "        if (epoch + 1) % 400 == 0:\n",
    "            print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))\n",
    "\n",
    "    predict_batch = [np.eye(n_class)[[word_num[n] for n in 'P P P P P'.split()]]]\n",
    "    result = sess.run(prediction, feed_dict={enc_input: input_batch, dec_input: predict_batch})\n",
    "    print(sentences[0].split(), '->', [num_word[n] for n in result[0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
